import numpy as np
import scipy.special as sp

def calShannonCap(sbr, mod, snrList=(20,), *, verbose=False):
	"""Return the NRZ bit-error rate for each SNR in ``snrList``.

	The pulse response ``sbr`` (presumably the single-bit response) is
	normalised to unit L2 norm, so each SNR is referenced to the total pulse
	energy: the noise RMS is ``sigma = 10**(-snr/20)``. The largest
	normalised sample is taken as the main cursor and the BER is the
	Gaussian tail probability Q(mainCursor / sigma). The other samples only
	influence the result through the normalisation.

	Args:
		sbr: Sequence of pulse-response samples. Integer input is accepted.
		mod: Modulation name. Only ``"nrz"`` is implemented.
		snrList: Iterable of SNR values in dB. The default is a single 20 dB point.
		verbose: If True, print intermediate values for every SNR point.

	Returns:
		List with one BER value per entry of ``snrList``.

	Raises:
		ValueError: If ``sbr`` is empty or all-zero, or ``mod`` is unknown.
		NotImplementedError: If ``mod`` is ``"pam4"``.
	"""
	# Force a float copy. The in-place division below fails on an integer
	# array, and copying also guarantees the caller's data is never modified.
	sbr = np.array(sbr, dtype=float)
	norm = np.linalg.norm(sbr)
	if norm == 0:
		raise ValueError("sbr must contain at least one non-zero sample")
	sbr /= norm
	mainCursor = np.max(sbr)

	if mod == "pam4":
		# No PAM4 noise/BER model exists yet. Fail loudly rather than return garbage.
		raise NotImplementedError("PAM4 BER is not implemented")
	if mod != "nrz":
		raise ValueError(f"unsupported modulation: {mod!r}")

	berList = []
	for snr in snrList:
		# Noise RMS relative to the unit-energy pulse.
		sigma = 10**(-(float(snr))/20)
		# Q(x) = 0.5 * erfc(x / sqrt(2)).
		ber = sp.erfc(mainCursor/sigma/np.sqrt(2))*0.5

		if verbose:
			print (f"sbr: {sbr}")
			print (f"sum(sbr^2): {np.sum(sbr**2)}")
			print (f"mainCursor: {mainCursor}")
			print (f"sigma: {sigma}")

		berList.append(ber)

	return berList

def calShannonCap2(sbr, mod, snrList=(20,), *, verbose=False):
	"""Return the binary entropy H(q), in bits, of the NRZ error probability q.

	``q`` is the NRZ bit-error rate Q(mainCursor / sigma). The pulse response
	``sbr`` is normalised to unit L2 norm, its largest sample is the main
	cursor, and ``sigma = 10**(-snr/20)``. The returned value H(q) is the
	equivocation H(X|Y) of a binary symmetric channel with crossover
	probability q and equiprobable inputs. The corresponding channel
	capacity per bit would be ``1 - H(q)``; this function returns H(q)
	itself.

	Args:
		sbr: Sequence of pulse-response samples. Integer input is accepted.
		mod: Modulation name. Only ``"nrz"`` is implemented.
		snrList: Iterable of SNR values in dB. The default is a single 20 dB point.
		verbose: If True, print intermediate values for every SNR point.

	Returns:
		List with one H(q) value (bits) per entry of ``snrList``.

	Raises:
		ValueError: If ``sbr`` is empty or all-zero, or ``mod`` is unknown.
		NotImplementedError: If ``mod`` is ``"pam4"``.
	"""
	# Force a float copy. The in-place division below fails on an integer array.
	sbr = np.array(sbr, dtype=float)
	norm = np.linalg.norm(sbr)
	if norm == 0:
		raise ValueError("sbr must contain at least one non-zero sample")
	sbr /= norm
	mainCursor = np.max(sbr)

	if mod == "pam4":
		# No PAM4 noise/BER model exists yet. Fail loudly rather than return garbage.
		raise NotImplementedError("PAM4 BER is not implemented")
	if mod != "nrz":
		raise ValueError(f"unsupported modulation: {mod!r}")

	entropyList = []
	for snr in snrList:
		sigma = 10**(-(float(snr))/20)
		q = sp.erfc(mainCursor/sigma/np.sqrt(2))*0.5
		# H(q) = -q*log2(q) - (1-q)*log2(1-q). entr() and xlog1py() define
		# 0*log(0) = 0, so q == 0 (erfc underflow at high SNR) or q == 1
		# give 0 instead of NaN. log1p keeps the (1-q) term accurate for tiny q.
		hxy = (sp.entr(q) - sp.xlog1py(1 - q, -q)) / np.log(2)

		if verbose:
			print (f"sbr: {sbr}")
			print (f"sum(sbr^2): {np.sum(sbr**2)}")
			print (f"mainCursor: {mainCursor}")
			print (f"sigma: {sigma}")

		entropyList.append(hxy)

	return entropyList


if __name__ == "__main__":
	# Demo: ideal single-tap pulse, NRZ, SNR swept from 15 dB down to 10.5 dB.
	# Print the BER list first, then the binary-entropy list, for the same sweep.
	demoSbr = [1.0]
	demoSnrs = np.arange(15, 10, -0.5)

	for calculator in (calShannonCap, calShannonCap2):
		print(calculator(demoSbr, 'nrz', snrList=demoSnrs))